lubridate for working with dates in R and ggplot2
library(tidyverse)
library(ggthemes)
library(knitr)
library(broom)
library(stringr)
library(modelr)
options(digits = 3)
set.seed(1234)
theme_set(theme_minimal())
For more details on the lubridate package, check out R for Data Science.
When importing data, date variables can be tricky to store and use correctly. In a spreadsheet or other tabular format, dates will by default appear either as numeric (20170418) or string (2017-04-18, April 18th, 2017, etc.) columns. If we want to perform tasks such as extracting and summarizing over individual components (year, month, day, etc.), we need to represent dates in a different, yet standardized, format.
lubridate is a tidyverse package that facilitates working with dates (and date-times) in R.
library(lubridate)
##
## Attaching package: 'lubridate'
## The following object is masked from 'package:base':
##
## date
When using readr to import data files, readr will use parse_date() or parse_datetime() to try to parse any columns it thinks contain dates or date-times. To manually parse dates from strings, use the lubridate function that combines y, m, and d in the order in which the year, month, and day appear in the original string:
ymd("2017-01-31")
## [1] "2017-01-31"
mdy("January 31st, 2017")
## [1] "2017-01-31"
dmy("31-Jan-2017")
## [1] "2017-01-31"
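As a minimal sketch of the point above, the three parsers can be checked against one another. The variable names `parsed` and `bad` are illustrative, and the invalid date is an assumed example, not part of the original data.

```r
library(lubridate)

# Three differently formatted strings that all describe the same day
parsed <- c(
  ymd("2017-01-31"),
  mdy("January 31st, 2017"),
  dmy("31-Jan-2017")
)

# Every parser returns a Date object, so the results can be compared directly
class(parsed)
all(parsed == as.Date("2017-01-31"))

# A date that does not exist on the calendar fails to parse and becomes NA
# (lubridate also issues a warning, suppressed here)
bad <- suppressWarnings(ymd("2017-02-30"))
is.na(bad)
```

Because the result is a proper Date rather than a string, it can be passed straight to the component extractors (year(), month(), and so on) used later.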
Let’s practice extracting components of dates using an example dataset. flights-departed.csv is a time series data file containing the daily number of departing commercial flights in the United States from 1988-2008.
(flights <- read_csv("data/flights-departed.csv"))
## Parsed with column specification:
## cols(
## date = col_date(format = ""),
## value = col_integer()
## )
## # A tibble: 7,671 × 2
## date value
## <date> <int>
## 1 1988-01-01 12681
## 2 1988-01-02 13264
## 3 1988-01-03 13953
## 4 1988-01-04 13921
## 5 1988-01-05 13932
## 6 1988-01-06 13157
## 7 1988-01-07 11159
## 8 1988-01-08 11631
## 9 1988-01-09 12045
## 10 1988-01-10 13160
## # ... with 7,661 more rows
We will use ggplot2 to generate several graphs based on the data. The first is a simple line plot over time of the daily commercial flights. To build this, we don’t need to modify flights:
ggplot(flights, aes(date, value)) +
  geom_line() +
  labs(x = NULL,
       y = "Number of departing commercial flights")
But this is quite noisy. Instead, let’s draw a line plot depicting commercial flights over a one-year period, with separate lines for each year in the data (1988, 1989, 1990, etc.). To do that, we need to create a new variable year which will serve as our grouping variable in ggplot():
(flights <- flights %>%
  mutate(year = year(date),
         yday = yday(date),
         # hack to label the x-axis with months
         days = dmy(format(date, "%d-%m-2016"))))
## # A tibble: 7,671 × 5
## date value year yday days
## <date> <int> <dbl> <dbl> <date>
## 1 1988-01-01 12681 1988 1 2016-01-01
## 2 1988-01-02 13264 1988 2 2016-01-02
## 3 1988-01-03 13953 1988 3 2016-01-03
## 4 1988-01-04 13921 1988 4 2016-01-04
## 5 1988-01-05 13932 1988 5 2016-01-05
## 6 1988-01-06 13157 1988 6 2016-01-06
## 7 1988-01-07 11159 1988 7 2016-01-07
## 8 1988-01-08 11631 1988 8 2016-01-08
## 9 1988-01-09 12045 1988 9 2016-01-09
## 10 1988-01-10 13160 1988 10 2016-01-10
## # ... with 7,661 more rows
ggplot(flights, aes(days, value)) +
  geom_line(aes(group = year), alpha = .2) +
  geom_smooth(se = FALSE) +
  scale_x_date(labels = scales::date_format("%b")) +
  labs(x = NULL,
       y = "Number of departing commercial flights")
## `geom_smooth()` using method = 'gam'
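The days variable above maps every observation onto the same calendar year so that the years overlay on one axis. One alternative way to express the same idea, sketched here on a few made-up dates rather than the flights data, is lubridate's year<- replacement function. This is offered as an assumption about what might read more clearly, not as the approach the original code takes.

```r
library(lubridate)

dates <- ymd(c("1988-02-29", "1995-07-04", "2001-09-11"))

# Copy the dates, then overwrite only the year component.
# 2016 is a leap year, so February 29th remains a valid date.
same_year <- dates
year(same_year) <- 2016
same_year

# The month and day are unchanged; only the year differs
all(month(same_year) == month(dates))
all(day(same_year) == day(dates))
```

Choosing a leap year as the common year matters: with a non-leap year, February 29th observations would have no valid date to map onto.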
Or we could summarize the distribution of daily departing commercial flights within each month over the 21-year period (1988 through 2008):
(flights <- flights %>%
  mutate(month = month(date, label = TRUE)))
## # A tibble: 7,671 × 6
## date value year yday days month
## <date> <int> <dbl> <dbl> <date> <ord>
## 1 1988-01-01 12681 1988 1 2016-01-01 Jan
## 2 1988-01-02 13264 1988 2 2016-01-02 Jan
## 3 1988-01-03 13953 1988 3 2016-01-03 Jan
## 4 1988-01-04 13921 1988 4 2016-01-04 Jan
## 5 1988-01-05 13932 1988 5 2016-01-05 Jan
## 6 1988-01-06 13157 1988 6 2016-01-06 Jan
## 7 1988-01-07 11159 1988 7 2016-01-07 Jan
## 8 1988-01-08 11631 1988 8 2016-01-08 Jan
## 9 1988-01-09 12045 1988 9 2016-01-09 Jan
## 10 1988-01-10 13160 1988 10 2016-01-10 Jan
## # ... with 7,661 more rows
ggplot(flights, aes(month, value)) +
  geom_violin() +
  geom_boxplot(width = .1, outlier.shape = NA) +
  labs(x = NULL,
       y = "Number of departing commercial flights")
Hmmm, there seems to be an outlier in September. What’s up with that?
Finally, we can generate a heatmap depicting the change over time of this data by creating a calendar-like visualization.[1] In order to do this, we need the following grammar for the graph:
- Fill: value (number of departing flights)
- Statistical transformation: identity
- Geometric object: geom_tile()
- Facets: facet_grid() (year X month)
In order to generate this graph, we need to create several new variables for flights: the year, the month, the day of the week, and the week within the month.
We can use lubridate to directly generate three of those variables (we’ve already generated year and month):
(flights <- flights %>%
  mutate(weekday = wday(date, label = TRUE)))
## # A tibble: 7,671 × 7
## date value year yday days month weekday
## <date> <int> <dbl> <dbl> <date> <ord> <ord>
## 1 1988-01-01 12681 1988 1 2016-01-01 Jan Fri
## 2 1988-01-02 13264 1988 2 2016-01-02 Jan Sat
## 3 1988-01-03 13953 1988 3 2016-01-03 Jan Sun
## 4 1988-01-04 13921 1988 4 2016-01-04 Jan Mon
## 5 1988-01-05 13932 1988 5 2016-01-05 Jan Tues
## 6 1988-01-06 13157 1988 6 2016-01-06 Jan Wed
## 7 1988-01-07 11159 1988 7 2016-01-07 Jan Thurs
## 8 1988-01-08 11631 1988 8 2016-01-08 Jan Fri
## 9 1988-01-09 12045 1988 9 2016-01-09 Jan Sat
## 10 1988-01-10 13160 1988 10 2016-01-10 Jan Sun
## # ... with 7,661 more rows
We use label = TRUE to generate ordered factor labels for these values (Jan, Feb, Mar for months; Sun, Mon, Tues for weekdays) instead of the numeric equivalents (1, 2, 3).
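A short sketch of the difference between the numeric and labelled forms, using a single date taken from the first row of the data. The exact label text depends on the lubridate version and locale, so no printed output is shown here.

```r
library(lubridate)

d <- ymd("1988-01-01")

# Numeric components
year(d)    # 1988
month(d)   # 1
yday(d)    # 1 (day of the year)
wday(d)    # day of the week as a number; by default Sunday is 1

# Labelled components are ordered factors, which keeps months and
# weekdays in calendar order on a ggplot2 axis
month(d, label = TRUE)
wday(d, label = TRUE)
is.ordered(month(d, label = TRUE))

# abbr = FALSE spells the label out in full
month(d, label = TRUE, abbr = FALSE)
```

The ordered factor is what lets the violin plot and the facets list months from January to December rather than alphabetically.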
To generate the final week-in-month variable, we need to combine a few lubridate functions to get exactly what we want:
(flights <- flights %>%
  # generate variables for week in the year (1-53) and the day in the year (1-366)
  mutate(week = week(date),
         yday = yday(date)) %>%
  # normalize to draw the calendar correctly: shift yday by the weekday of January 1st
  # so that it counts days from the Sunday of the week containing January 1st,
  # then recompute the week number so that weeks break on Sundays
  group_by(year) %>%
  mutate(yday = yday + wday(date)[1] - 2,
         week = floor(yday / 7)) %>%
  group_by(year, month) %>%
  mutate(week_month = week - min(week) + 1))
## Source: local data frame [7,671 x 9]
## Groups: year, month [252]
##
## date value year yday days month weekday week week_month
## <date> <int> <dbl> <dbl> <date> <ord> <ord> <dbl> <dbl>
## 1 1988-01-01 12681 1988 5 2016-01-01 Jan Fri 0 1
## 2 1988-01-02 13264 1988 6 2016-01-02 Jan Sat 0 1
## 3 1988-01-03 13953 1988 7 2016-01-03 Jan Sun 1 2
## 4 1988-01-04 13921 1988 8 2016-01-04 Jan Mon 1 2
## 5 1988-01-05 13932 1988 9 2016-01-05 Jan Tues 1 2
## 6 1988-01-06 13157 1988 10 2016-01-06 Jan Wed 1 2
## 7 1988-01-07 11159 1988 11 2016-01-07 Jan Thurs 1 2
## 8 1988-01-08 11631 1988 12 2016-01-08 Jan Fri 1 2
## 9 1988-01-09 12045 1988 13 2016-01-09 Jan Sat 1 2
## 10 1988-01-10 13160 1988 14 2016-01-10 Jan Sun 2 3
## # ... with 7,661 more rows
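The week-in-month logic can also be written as a single formula per date. The sketch below uses a hypothetical helper, `week_of_month()`, which is not part of lubridate or of the original code; for the first ten days of January 1988 it gives the same week_month values as the table above.

```r
library(lubridate)

# Hypothetical helper: which row of a Sunday-first calendar page
# does each date fall on within its own month?
week_of_month <- function(date) {
  # weekday (1 = Sunday) of the first day of the month
  first_wday <- wday(floor_date(date, unit = "month"), week_start = 7)
  # offset the day of the month by the blank cells before the 1st,
  # then count rows of seven
  ceiling((mday(date) + first_wday - 1) / 7)
}

jan_1988 <- ymd("1988-01-01") + days(0:9)
data.frame(
  date = jan_1988,
  weekday = wday(jan_1988, label = TRUE),
  week_month = week_of_month(jan_1988)
)
```

January 1st, 1988 was a Friday, so the first calendar row holds only Friday and Saturday, and Sunday the 3rd starts row 2.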
Now that we have the data correctly formatted and all the components are extracted, we can draw the graph:
ggplot(flights, aes(weekday, week_month, fill = value)) +
  facet_grid(year ~ month) +
  geom_tile(color = "black") +
  scale_fill_continuous(low = "green", high = "red") +
  scale_x_discrete(labels = NULL) +
  scale_y_reverse(labels = NULL) +
  labs(title = "Domestic commercial flight activity",
       x = NULL,
       y = NULL,
       fill = "Number of departing flights") +
  theme_void() +
  theme(legend.position = "bottom",
        legend.text = element_text(angle = 45))
Aha, now the outlier makes sense. In the days following the September 11th attacks, the United States grounded virtually all commercial air traffic.
When examining multivariate continuous data, scatterplots are a quick and easy visualization to assess relationships. However, if the data points become too densely clustered, interpreting the graph becomes difficult. Consider the diamonds dataset:
p <- ggplot(diamonds, aes(carat, price)) +
  geom_point() +
  scale_y_continuous(labels = scales::dollar) +
  labs(x = "Carat size",
       y = "Price")
p
What is the relationship between carat size and price? It appears positive, but there are also a lot of densely packed data points in the middle of the graph. Smoothing lines are a method for summarizing the relationship between variables to capture important patterns by approximating the functional form of the relationship. The functional form can take on many shapes. For instance, a very common functional form is a best-fit line, also known as ordinary least squares (OLS) or simple linear regression. We can estimate the model directly using lm(), or we can directly plot the line by using geom_smooth(method = "lm"):
p +
  geom_smooth(method = "lm", se = FALSE)
The downside to a linear best-fit line is that it assumes the relationship between the variables is additive and monotonic. As a result, the summarized relationship between carat size and price seems wildly incorrect for diamonds with a carat size larger than 3. Instead we could use a generalized additive model, which allows for flexible, non-linear relationships between the variables while still implementing a basic regression approach:[2]
p +
  geom_smooth(se = FALSE)
## `geom_smooth()` using method = 'gam'
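To see why the straight line looks wrong for very large diamonds, the linear model can be estimated directly with lm(). This is a rough sketch: the 4-carat diamond is a hypothetical input chosen for illustration, and `fit` is just a local name.

```r
library(ggplot2)  # provides the diamonds data

# Ordinary least squares: price as a linear function of carat
fit <- lm(price ~ carat, data = diamonds)
coef(fit)

# Predicted price for a hypothetical 4-carat diamond, compared with the
# most expensive diamond actually observed
predict(fit, newdata = data.frame(carat = 4))
max(diamonds$price)
```

The linear fit keeps climbing at a constant rate, so its prediction for a large stone overshoots every price in the data, whereas the flexible smoother is free to flatten out where the observations do.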
Locally weighted scatterplot smoothing (local regression, LOWESS, or LOESS) fits a separate non-linear function at each target point \(x_0\) using only the nearby training observations. This method estimates a regression line based on localized subsets of the data, building up the global function \(f\) point-by-point.
Here is an example of a local linear regression on the ethanol dataset in the lattice package:
The LOESS is built up point-by-point:
One important argument you can control with LOESS is the span, or how smooth the LOESS function will become. A larger span will result in a smoother curve, but may not be as accurate. A smaller span will be more local and wiggly, but improve our fit to the training data.
LOESS lines are best used for datasets with fewer than 1000 observations; beyond that, the time and memory required to compute the line grow rapidly (memory use is roughly quadratic in the number of observations).
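The effect of the span can be sketched with base R's loess() on simulated data. The sine-plus-noise data below is made up purely for illustration, and the two span values are arbitrary choices.

```r
set.seed(1234)

# Simulated data: a sine curve plus noise (purely illustrative)
sim <- data.frame(x = seq(0, 10, length.out = 200))
sim$y <- sin(sim$x) + rnorm(200, sd = 0.3)

# Fit LOESS with a small and a large span
fit_wiggly <- loess(y ~ x, data = sim, span = 0.2)
fit_smooth <- loess(y ~ x, data = sim, span = 0.9)

# Residual sum of squares on the training data
rss <- c(
  span_0.2 = sum(residuals(fit_wiggly)^2),
  span_0.9 = sum(residuals(fit_smooth)^2)
)
rss
```

The smaller span tracks the training data more closely and so has the lower residual sum of squares, at the cost of a wigglier curve. In ggplot2 the same argument can be passed through the smoother, for example geom_smooth(method = "loess", span = 0.2).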
r_plot <- function(r, n = 100){
  xy <- ecodist::corgen(len = n, r = r) %>%
    bind_cols
  ggplot(xy, aes(x, y)) +
    geom_point() +
    ggtitle(str_c("Pearson's r = ", r))
}
r <- c(.8, 0, -.8)
for(r in r){
  print(r_plot(r))
}
To quickly visualize several variables in a dataset and their relation to one another, a scatterplot matrix is a quick and detailed tool for generating a series of scatterplots for each combination of variables. Consider credit.csv which contains a sample of individuals from a credit card company, identifying their total amount of credit card debt and other financial/demographic variables:
credit <- read_csv("data/Credit.csv") %>%
  # remove first ID column
  select(-X1)
## Warning: Missing column names filled in: 'X1' [1]
## Parsed with column specification:
## cols(
## X1 = col_integer(),
## Income = col_double(),
## Limit = col_integer(),
## Rating = col_integer(),
## Cards = col_integer(),
## Age = col_integer(),
## Education = col_integer(),
## Gender = col_character(),
## Student = col_character(),
## Married = col_character(),
## Ethnicity = col_character(),
## Balance = col_integer()
## )
names(credit) <- stringr::str_to_lower(names(credit)) # convert column names to lowercase
str(credit)
## Classes 'tbl_df', 'tbl' and 'data.frame': 400 obs. of 11 variables:
## $ income : num 14.9 106 104.6 148.9 55.9 ...
## $ limit : int 3606 6645 7075 9504 4897 8047 3388 7114 3300 6819 ...
## $ rating : int 283 483 514 681 357 569 259 512 266 491 ...
## $ cards : int 2 3 4 3 2 4 2 2 5 3 ...
## $ age : int 34 82 71 36 68 77 37 87 66 41 ...
## $ education: int 11 15 11 11 16 10 12 9 13 19 ...
## $ gender : chr "Male" "Female" "Male" "Female" ...
## $ student : chr "No" "Yes" "No" "No" ...
## $ married : chr "Yes" "Yes" "No" "No" ...
## $ ethnicity: chr "Caucasian" "Asian" "Asian" "Asian" ...
## $ balance : int 333 903 580 964 331 1151 203 872 279 1350 ...
## - attr(*, "spec")=List of 2
## ..$ cols :List of 12
## .. ..$ X1 : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ Income : list()
## .. .. ..- attr(*, "class")= chr "collector_double" "collector"
## .. ..$ Limit : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ Rating : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ Cards : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ Age : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ Education: list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ Gender : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ Student : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ Married : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ Ethnicity: list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ Balance : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## ..$ default: list()
## .. ..- attr(*, "class")= chr "collector_guess" "collector"
## ..- attr(*, "class")= chr "col_spec"
If we want to quickly assess the relationship between all of the variables (in preparation for more advanced statistical learning techniques), we could generate a matrix of scatterplots using the base graphics package:
pairs(select_if(credit, is.numeric))
This approach has drawbacks:
- We had to restrict the data to numeric columns (hence the use of select_if())
- It is not built on ggplot2, so it’s hard to modify using techniques with which we are already familiar
Instead, we can use GGally::ggpairs() to generate a scatterplot matrix. GGally is a package for R that extends ggplot2 by adding helper functions for common multivariate data structures. ggpairs() is a function that allows us to quickly generate a scatterplot matrix.
library(GGally)
ggpairs(select_if(credit, is.numeric))
When applied to strictly numeric variables, the lower triangle generates scatterplots, the upper triangle prints the correlation coefficient, and the diagonal panels are density plots of the variable.
Because ggpairs() is ultimately based on ggplot(), we can use the same types of commands to modify the graph. For instance, if we want to use the color aesthetic to distinguish between men and women in the dataset:
ggpairs(credit, mapping = aes(color = gender),
        columns = c("income", "limit", "rating", "cards", "age", "education", "balance"))
Or if we wanted to draw a smoothing line instead of scatterplots, we can modify the graph’s matrix sections:
ggpairs(select_if(credit, is.numeric),
        lower = list(
          continuous = "smooth"
        )
)
Hmm, it’s too difficult to see the smoothers because the points are so dense. We can use wrap() to pass individual parameters through to the underlying geom_*() function:
ggpairs(select_if(credit, is.numeric),
        lower = list(
          continuous = wrap("smooth", alpha = .1, color = "blue")
        )
)
Or we can write a custom function and apply it to the lower triangle panels:
scatter_smooth <- function(data, mapping, ...) {
  ggplot(data = data, mapping = mapping) +
    # make data points transparent
    geom_point(alpha = .2) +
    # add default smoother
    geom_smooth(se = FALSE)
}
ggpairs(select_if(credit, is.numeric),
        lower = list(
          continuous = scatter_smooth
        )
)
ggpairs(credit, mapping = aes(color = gender),
        columns = c("income", "limit", "rating", "cards", "age", "education", "balance"),
        lower = list(
          continuous = scatter_smooth
        )
)
ggpairs() also works on datasets with a mix of qualitative and quantitative variables, drawing appropriate graphs based on whether the variables are continuous or discrete:
ggpairs(select(rcfss::scorecard, type:debt))
## Warning: Removed 21 rows containing non-finite values (stat_boxplot).
## Warning: Removed 471 rows containing non-finite values (stat_boxplot).
## Warning: Removed 22 rows containing non-finite values (stat_boxplot).
## Warning: Removed 1 rows containing non-finite values (stat_boxplot).
## Warning: Removed 69 rows containing non-finite values (stat_boxplot).
## Warning: Removed 124 rows containing non-finite values (stat_boxplot).
## Warning: Removed 74 rows containing non-finite values (stat_boxplot).
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 21 rows containing non-finite values (stat_bin).
## Warning: Removed 21 rows containing non-finite values (stat_density).
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 21 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 474 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 42 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 21 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 82 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 136 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 87 rows containing missing values
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 21 rows containing missing values (geom_point).
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 471 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 22 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removing 1 row that contained a missing value
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 69 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 124 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 74 rows containing missing values
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 471 rows containing non-finite values (stat_bin).
## Warning: Removed 474 rows containing missing values (geom_point).
## Warning: Removed 471 rows containing missing values (geom_point).
## Warning: Removed 471 rows containing non-finite values (stat_density).
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 474 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 471 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 479 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 508 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 476 rows containing missing values
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 22 rows containing non-finite values (stat_bin).
## Warning: Removed 42 rows containing missing values (geom_point).
## Warning: Removed 22 rows containing missing values (geom_point).
## Warning: Removed 474 rows containing missing values (geom_point).
## Warning: Removed 22 rows containing non-finite values (stat_density).
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 23 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 83 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 142 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 93 rows containing missing values
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 1 rows containing non-finite values (stat_bin).
## Warning: Removed 21 rows containing missing values (geom_point).
## Warning: Removed 1 rows containing missing values (geom_point).
## Warning: Removed 471 rows containing missing values (geom_point).
## Warning: Removed 23 rows containing missing values (geom_point).
## Warning: Removed 1 rows containing non-finite values (stat_density).
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 69 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 125 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 75 rows containing missing values
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 69 rows containing non-finite values (stat_bin).
## Warning: Removed 82 rows containing missing values (geom_point).
## Warning: Removed 69 rows containing missing values (geom_point).
## Warning: Removed 479 rows containing missing values (geom_point).
## Warning: Removed 83 rows containing missing values (geom_point).
## Warning: Removed 69 rows containing missing values (geom_point).
## Warning: Removed 69 rows containing non-finite values (stat_density).
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 185 rows containing missing values
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 135 rows containing missing values
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 124 rows containing non-finite values (stat_bin).
## Warning: Removed 136 rows containing missing values (geom_point).
## Warning: Removed 124 rows containing missing values (geom_point).
## Warning: Removed 508 rows containing missing values (geom_point).
## Warning: Removed 142 rows containing missing values (geom_point).
## Warning: Removed 125 rows containing missing values (geom_point).
## Warning: Removed 185 rows containing missing values (geom_point).
## Warning: Removed 124 rows containing non-finite values (stat_density).
## Warning in (function (data, mapping, alignPercent = 0.6, method =
## "pearson", : Removed 138 rows containing missing values
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 74 rows containing non-finite values (stat_bin).
## Warning: Removed 87 rows containing missing values (geom_point).
## Warning: Removed 74 rows containing missing values (geom_point).
## Warning: Removed 476 rows containing missing values (geom_point).
## Warning: Removed 93 rows containing missing values (geom_point).
## Warning: Removed 75 rows containing missing values (geom_point).
## Warning: Removed 135 rows containing missing values (geom_point).
## Warning: Removed 138 rows containing missing values (geom_point).
## Warning: Removed 74 rows containing non-finite values (stat_density).
Scatterplot matrices can provide lots of information, but can also be very densely packed. Perhaps instead we want to quickly visualize the correlation between each of the variables.[3] We can easily calculate the correlation coefficients using cor():
(mpg_lite <- select_if(mpg, is.numeric))
## # A tibble: 234 × 5
## displ year cyl cty hwy
## <dbl> <int> <int> <int> <int>
## 1 1.8 1999 4 18 29
## 2 1.8 1999 4 21 29
## 3 2.0 2008 4 20 31
## 4 2.0 2008 4 21 30
## 5 2.8 1999 6 16 26
## 6 2.8 1999 6 18 26
## 7 3.1 2008 6 18 27
## 8 1.8 1999 4 18 26
## 9 1.8 1999 4 16 25
## 10 2.0 2008 4 20 28
## # ... with 224 more rows
(cormat <- mpg_lite %>%
  cor %>%
  round(2))
## displ year cyl cty hwy
## displ 1.00 0.15 0.93 -0.80 -0.77
## year 0.15 1.00 0.12 -0.04 0.00
## cyl 0.93 0.12 1.00 -0.81 -0.76
## cty -0.80 -0.04 -0.81 1.00 0.96
## hwy -0.77 0.00 -0.76 0.96 1.00
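Two properties of the correlation matrix matter for the heatmap that follows: the diagonal is always 1, and the matrix is symmetric. A small sketch on the built-in mtcars data (chosen only so the example is self-contained; it is not the dataset used above) also shows how missing values propagate unless cor() is told how to handle them.

```r
# Built-in data, chosen only so the example is self-contained
vars <- mtcars[, c("mpg", "cyl", "disp", "hp", "wt")]

cormat_demo <- round(cor(vars), 2)
cormat_demo

# Each variable correlates perfectly with itself...
all(diag(cormat_demo) == 1)
# ...and cor(x, y) equals cor(y, x), so the matrix is symmetric
isSymmetric(cormat_demo)

# With the default settings, a single missing value makes the
# affected coefficients NA
vars_na <- vars
vars_na$mpg[1] <- NA
cor(vars_na)["mpg", "wt"]

# use = "pairwise.complete.obs" drops incomplete pairs instead
cor(vars_na, use = "pairwise.complete.obs")["mpg", "wt"]
```

The symmetry is why only one triangle of the matrix needs to be drawn.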
But who likes yucky tables? Instead, let’s turn this into a heatmap. First we need to reshape the data into a tidy structure. What we need is a data frame with three columns: the first variable, the second variable, and the correlation coefficient for that pair.
We can use reshape2::melt() to quickly accomplish this:
library(reshape2)
##
## Attaching package: 'reshape2'
## The following object is masked from 'package:tidyr':
##
## smiths
(melted_cormat <- melt(cormat))
## Var1 Var2 value
## 1 displ displ 1.00
## 2 year displ 0.15
## 3 cyl displ 0.93
## 4 cty displ -0.80
## 5 hwy displ -0.77
## 6 displ year 0.15
## 7 year year 1.00
## 8 cyl year 0.12
## 9 cty year -0.04
## 10 hwy year 0.00
## 11 displ cyl 0.93
## 12 year cyl 0.12
## 13 cyl cyl 1.00
## 14 cty cyl -0.81
## 15 hwy cyl -0.76
## 16 displ cty -0.80
## 17 year cty -0.04
## 18 cyl cty -0.81
## 19 cty cty 1.00
## 20 hwy cty 0.96
## 21 displ hwy -0.77
## 22 year hwy 0.00
## 23 cyl hwy -0.76
## 24 cty hwy 0.96
## 25 hwy hwy 1.00
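If reshape2 is not available, a similar long table can be produced with base R alone. This is an alternative sketch, not the method used above: it converts a small matrix (`cm`, built from mtcars for illustration) to a table and then to a data frame.

```r
# A small correlation matrix from built-in data (illustrative only)
cm <- round(cor(mtcars[, c("mpg", "cyl", "wt")]), 2)

# as.table() + as.data.frame() produces one row per matrix cell,
# with the row name, the column name, and the cell value
long <- as.data.frame(as.table(cm), responseName = "value")
long

names(long)
nrow(long) == length(cm)   # a 3 x 3 matrix gives 9 rows
```

The resulting columns carry the same Var1, Var2, and value names as the melt() output, so the plotting code would not need to change.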
We can then use geom_tile() to visualize the correlation matrix:
ggplot(melted_cormat, aes(x = Var1, y = Var2, fill = value)) +
  geom_tile()
Not exactly pretty. We can clean it up first by reducing redundancy (remember the upper and lower triangles provide duplicate information):
# Get lower triangle of the correlation matrix
get_lower_tri <- function(cormat){
  cormat[upper.tri(cormat)] <- NA
  return(cormat)
}
# Get upper triangle of the correlation matrix
get_upper_tri <- function(cormat){
  cormat[lower.tri(cormat)] <- NA
  return(cormat)
}
upper_tri <- get_upper_tri(cormat)
upper_tri
## displ year cyl cty hwy
## displ 1 0.15 0.93 -0.80 -0.77
## year NA 1.00 0.12 -0.04 0.00
## cyl NA NA 1.00 -0.81 -0.76
## cty NA NA NA 1.00 0.96
## hwy NA NA NA NA 1.00
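The helper relies on lower.tri(), which returns a logical mask of the cells strictly below the diagonal. A tiny sketch on a made-up 3 x 3 matrix shows what gets blanked out:

```r
m <- matrix(1:9, nrow = 3)

# lower.tri() flags the cells strictly below the diagonal
lower.tri(m)

# Blank those cells out; the diagonal and upper triangle survive
m_upper <- m
m_upper[lower.tri(m_upper)] <- NA
m_upper

sum(is.na(m_upper))  # a 3 x 3 matrix has 3 cells below the diagonal
```

Because the diagonal is excluded from the mask, the self-correlations of 1 stay in the matrix and still appear in the heatmap.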
Now melt upper_tri and repeat the same process, cleaning up the colors for the heatmap as well to distinguish between positive and negative coefficients:
melted_cormat <- melt(upper_tri, na.rm = TRUE)
ggplot(melted_cormat, aes(Var2, Var1, fill = value)) +
  geom_tile(color = "white") +
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       midpoint = 0, limit = c(-1, 1), space = "Lab",
                       name = "Pearson\nCorrelation") +
  theme(axis.text.x = element_text(angle = 45, vjust = 1,
                                   size = 12, hjust = 1)) +
  coord_fixed()
We can also reorder the correlation matrix according to correlation coefficient to help reveal additional trends:
reorder_cormat <- function(cormat){
  # Use correlation between variables as distance
  dd <- as.dist((1 - cormat) / 2)
  hc <- hclust(dd)
  # return the matrix with rows and columns in clustering order
  cormat[hc$order, hc$order]
}
# Reorder the correlation matrix
cormat <- reorder_cormat(cormat)
upper_tri <- get_upper_tri(cormat)
# Melt the correlation matrix
melted_cormat <- melt(upper_tri, na.rm = TRUE)
# Create a ggheatmap
ggheatmap <- ggplot(melted_cormat, aes(Var2, Var1, fill = value)) +
  geom_tile(color = "white") +
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       midpoint = 0, limit = c(-1, 1), space = "Lab",
                       name = "Pearson\nCorrelation") +
  theme_minimal() + # minimal theme
  theme(axis.text.x = element_text(angle = 45, vjust = 1,
                                   size = 12, hjust = 1)) +
  coord_fixed()
# Print the heatmap
print(ggheatmap)
Finally we can directly label the correlation coefficient values on the graph, so we have both the color channel and exact values:
ggheatmap +
  geom_text(aes(Var2, Var1, label = value), color = "black", size = 4) +
  theme(
    axis.title.x = element_blank(),
    axis.title.y = element_blank(),
    panel.grid.major = element_blank(),
    panel.border = element_blank(),
    panel.background = element_blank(),
    axis.ticks = element_blank(),
    legend.position = "bottom")
To make it more flexible, we can also turn all of this into a function that works for any dataset:
cormat_heatmap <- function(data){
  # generate correlation matrix
  cormat <- round(cor(data), 2)
  # helper: keep only the upper triangle of the matrix
  get_upper_tri <- function(cormat){
    cormat[lower.tri(cormat)] <- NA
    return(cormat)
  }
  # helper: reorder matrix based on coefficient value
  reorder_cormat <- function(cormat){
    # Use correlation between variables as distance
    dd <- as.dist((1 - cormat) / 2)
    hc <- hclust(dd)
    cormat[hc$order, hc$order]
  }
  cormat <- reorder_cormat(cormat)
  upper_tri <- get_upper_tri(cormat)
  # Melt the correlation matrix into a tidy table
  melted_cormat <- melt(upper_tri, na.rm = TRUE)
  # Create a ggheatmap
  ggheatmap <- ggplot(melted_cormat, aes(Var2, Var1, fill = value)) +
    geom_tile(color = "white") +
    scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                         midpoint = 0, limit = c(-1, 1), space = "Lab",
                         name = "Pearson\nCorrelation") +
    theme_minimal() + # minimal theme
    theme(axis.text.x = element_text(angle = 45, vjust = 1,
                                     size = 12, hjust = 1)) +
    coord_fixed()
  # add correlation values to graph
  ggheatmap +
    geom_text(aes(Var2, Var1, label = value), color = "black", size = 4) +
    theme(
      axis.title.x = element_blank(),
      axis.title.y = element_blank(),
      panel.grid.major = element_blank(),
      panel.border = element_blank(),
      panel.background = element_blank(),
      axis.ticks = element_blank(),
      legend.position = "bottom")
}
cormat_heatmap(select_if(mpg, is.numeric))
cormat_heatmap(select_if(credit, is.numeric))
cormat_heatmap(select_if(diamonds, is.numeric))
stat_smooth() methods
devtools::session_info()
## Session info --------------------------------------------------------------
## setting value
## version R version 3.3.3 (2017-03-06)
## system x86_64, darwin13.4.0
## ui X11
## language (EN)
## collate en_US.UTF-8
## tz America/Chicago
## date 2017-04-19
## Packages ------------------------------------------------------------------
## package * version date source
## assertthat 0.1 2013-12-06 CRAN (R 3.3.0)
## backports 1.0.5 2017-01-18 CRAN (R 3.3.2)
## broom * 0.4.2 2017-02-13 CRAN (R 3.3.2)
## codetools 0.2-15 2016-10-05 CRAN (R 3.3.3)
## colorspace 1.3-2 2016-12-14 CRAN (R 3.3.2)
## DBI 0.6 2017-03-09 CRAN (R 3.3.3)
## devtools 1.12.0 2016-06-24 CRAN (R 3.3.0)
## digest 0.6.12 2017-01-27 CRAN (R 3.3.2)
## dplyr * 0.5.0 2016-06-24 CRAN (R 3.3.0)
## evaluate 0.10 2016-10-11 CRAN (R 3.3.0)
## forcats 0.2.0 2017-01-23 CRAN (R 3.3.2)
## foreign 0.8-67 2016-09-13 CRAN (R 3.3.3)
## GGally * 1.3.0 2016-11-13 CRAN (R 3.3.2)
## gganimate * 0.1 2016-11-11 Github (dgrtwo/gganimate@26ec501)
## ggplot2 * 2.2.1 2016-12-30 CRAN (R 3.3.2)
## ggthemes * 3.4.0 2017-02-19 CRAN (R 3.3.2)
## gtable 0.2.0 2016-02-26 CRAN (R 3.3.0)
## haven 1.0.0 2016-09-23 cran (@1.0.0)
## hms 0.3 2016-11-22 CRAN (R 3.3.2)
## htmltools 0.3.5 2016-03-21 CRAN (R 3.3.0)
## httr 1.2.1 2016-07-03 CRAN (R 3.3.0)
## jsonlite 1.3 2017-02-28 CRAN (R 3.3.2)
## knitr * 1.15.1 2016-11-22 cran (@1.15.1)
## labeling 0.3 2014-08-23 CRAN (R 3.3.0)
## lattice * 0.20-34 2016-09-06 CRAN (R 3.3.3)
## lazyeval 0.2.0 2016-06-12 CRAN (R 3.3.0)
## lubridate * 1.6.0 2016-09-13 CRAN (R 3.3.0)
## magrittr 1.5 2014-11-22 CRAN (R 3.3.0)
## memoise 1.0.0 2016-01-29 CRAN (R 3.3.0)
## mnormt 1.5-5 2016-10-15 CRAN (R 3.3.0)
## modelr * 0.1.0 2016-08-31 CRAN (R 3.3.0)
## munsell 0.4.3 2016-02-13 CRAN (R 3.3.0)
## nlme 3.1-131 2017-02-06 CRAN (R 3.3.3)
## plyr 1.8.4 2016-06-08 CRAN (R 3.3.0)
## psych 1.7.3.21 2017-03-22 CRAN (R 3.3.2)
## purrr * 0.2.2 2016-06-18 CRAN (R 3.3.0)
## R6 2.2.0 2016-10-05 CRAN (R 3.3.0)
## RColorBrewer 1.1-2 2014-12-07 CRAN (R 3.3.0)
## Rcpp 0.12.10 2017-03-19 cran (@0.12.10)
## readr * 1.1.0 2017-03-22 cran (@1.1.0)
## readxl 0.1.1 2016-03-28 CRAN (R 3.3.0)
## reshape 0.8.6 2016-10-21 CRAN (R 3.3.0)
## reshape2 * 1.4.2 2016-10-22 CRAN (R 3.3.0)
## rmarkdown 1.3 2016-12-21 CRAN (R 3.3.2)
## rprojroot 1.2 2017-01-16 CRAN (R 3.3.2)
## rvest 0.3.2 2016-06-17 CRAN (R 3.3.0)
## scales 0.4.1 2016-11-09 CRAN (R 3.3.1)
## stringi 1.1.2 2016-10-01 CRAN (R 3.3.0)
## stringr * 1.2.0 2017-02-18 CRAN (R 3.3.2)
## tibble * 1.2 2016-08-26 cran (@1.2)
## tidyr * 0.6.1 2017-01-10 CRAN (R 3.3.2)
## tidyverse * 1.1.1 2017-01-27 CRAN (R 3.3.2)
## withr 1.0.2 2016-06-20 CRAN (R 3.3.0)
## xml2 1.1.1 2017-01-24 CRAN (R 3.3.2)
## yaml 2.1.14 2016-11-12 cran (@2.1.14)
[2] geom_smooth() automatically implements the gam method for datasets with 1,000 or more observations.
[3] Example drawn from ggplot2 : Quick correlation matrix heatmap - R software and data visualization.